ECE 57000 Assignment 08 Exercise¶

Your Name: TJ Wiegman

For this assignment, you will do an ablation study on the DCGAN model discussed in class and implement WGAN with weight clipping and (optionally) WGAN with gradient penalty.

Exercise 1: Ablation Study on DCGAN¶

An ablation study measures performance changes after changing certain components in the AI system. The goal is to understand the contribution of each component to the overall system.

Task 1.0 Original DCGAN on MNIST from class note¶

Here is a copy of the code implementation from the course website. Please run the code to obtain the result and use it as a baseline against which to compare the results of the following ablation tasks.

Hyper-parameter and Dataloader setup¶

In [1]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.backends.cudnn.deterministic = True
# BUG FIX: the attribute is `benchmark` (singular). The original assignment to
# the non-existent `benchmarks` silently created a new module attribute and
# left cuDNN autotuning enabled, undermining the reproducibility setup above.
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(manualSeed)

# Root directory for dataset
# dataroot = "data/celeba"

# Number of workers for dataloader
workers = 1

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer. 32 is used (instead of the DCGAN-paper 64)
#   because MNIST digits are 28 x 28 and 32 is the nearest power of 2.
#image_size = 64
image_size = 32

# Number of channels in the training images. MNIST is grayscale, so 1
# (for color images this would be 3).
#nc = 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator (reduced from 64 to keep training fast)
#ngf = 64
ngf = 8

# Size of feature maps in discriminator (reduced from 64 to keep training fast)
#ndf = 64
ndf = 8

# Number of training epochs
num_epochs = 5
num_epochs_wgan = 15
num_iters = 250

# Learning rate for optimizers (lr_rms is for the RMSprop-based WGAN runs)
lr = 0.0002
lr_rms = 5e-4

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
# (floats, because BCELoss compares against float targets)
real_label = 1.0
fake_label = 0.0

# Several useful functions
def initialize_net(net_class, init_method, device, ngpu):
    """Instantiate a network, move it to `device`, optionally wrap it for
    multi-GPU execution, and apply a weight-initialization function.

    Args:
        net_class: callable taking `ngpu` and returning an nn.Module.
        init_method: function passed to `Module.apply`, or None to skip
            custom initialization.
        device: torch.device the model is moved to.
        ngpu: number of GPUs; >1 on a CUDA device enables DataParallel.

    Returns:
        The (possibly DataParallel-wrapped) initialized module.
    """
    model = net_class(ngpu).to(device)

    # Wrap for multi-GPU only when more than one CUDA device is requested.
    if device.type == 'cuda' and ngpu > 1:
        model = nn.DataParallel(model, list(range(ngpu)))

    # Randomly initialize all weights (e.g. to mean=0, stdev=0.02 for DCGAN)
    # when an initializer is supplied.
    if init_method is not None:
        model.apply(init_method)

    # Echo the architecture into the notebook log.
    print(model)

    return model

def plot_GAN_loss(losses, labels):
    """Plot one training-loss curve per (loss, label) pair on a shared figure.

    Args:
        losses: iterable of per-iteration loss sequences.
        labels: iterable of legend labels, parallel to `losses`.
    """
    plt.figure(figsize=(10, 5))
    plt.title("Losses During Training")

    # One curve per tracked quantity (e.g. generator and discriminator).
    for curve, name in zip(losses, labels):
        plt.plot(curve, label=f"{name}")

    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def plot_real_fake_images(real_batch, fake_batch):
    """Show a grid of real training images beside the latest fake-image grid.

    Args:
        real_batch: a dataloader batch; real_batch[0] holds the image tensor.
        fake_batch: list of pre-built image grids; the last entry is shown.
    """
    fig = plt.figure(figsize=(15, 15))

    # Left panel: grid built from the first 64 real images.
    ax_real = fig.add_subplot(1, 2, 1)
    ax_real.axis("off")
    ax_real.set_title("Real Images")
    real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
    ax_real.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

    # Right panel: most recent generator sample grid (from the last epoch).
    ax_fake = fig.add_subplot(1, 2, 2)
    ax_fake.axis("off")
    ax_fake.set_title("Fake Images")
    ax_fake.imshow(np.transpose(fake_batch[-1], (1, 2, 0)))
    plt.show()


# custom weights initialization called on netG and netD (via Module.apply)
def weights_init(m):
    """DCGAN weight initialization: N(0, 0.02) for conv-type layers;
    N(1, 0.02) weights and zero bias for batch-norm layers."""
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Download the MNIST dataset. Images are resized from 28 x 28 to 32 x 32 so
# every spatial dimension is a power of 2, then normalized to [-1, 1] to
# match the generator's tanh output range.
mnist_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = dset.MNIST('data', train=True, download=True, transform=mnist_transform)

# Create the dataloader
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)))
Random Seed:  999
Out[1]:
<matplotlib.image.AxesImage at 0x7f7ad0cbffa0>

Architectural design for generator and discriminator¶

In [2]:
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector (nz x 1 x 1) to a 1-channel
    32 x 32 image with pixel values in [-1, 1] (tanh matches the data,
    which is normalized to the same range)."""

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Each ConvTranspose2d doubles spatial size (except the first, which
        # expands 1 x 1 -> 4 x 4) while halving the channel count.
        layers = [
            # nz x 1 x 1 -> (ngf*4) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),  # in-place ReLU
            # -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> nc x 32 x 32, squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN discriminator: maps a 1-channel 32 x 32 image to a single
    probability (sigmoid output) of the image being real."""

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        # Mirror of the generator: each strided conv halves spatial size
        # while doubling the channel count; LeakyReLU per the DCGAN paper.
        layers = [
            # nc x 32 x 32 -> ndf x 16 x 16 (no BatchNorm on the input layer)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1, interpreted as P(real)
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

Loss function and Training function¶

In [3]:
# Initialize networks
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))


# Training Loop
#
# Per batch: one discriminator update, then one generator update.  The
# discriminator sees real and fake images in SEPARATE minibatches, with
# gradients accumulated across the two backward passes — the DCGAN trick
# that Task 1.2 ablates.

# Lists to keep track of progress
img_list = []   # sample grids from fixed_noise, for visualizing progress
G_losses = []   # per-iteration generator loss
D_losses = []   # per-iteration discriminator loss
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # (detach so this backward pass does not flow into G's graph)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        # (these ACCUMULATE onto the real-batch gradients from above)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # (no detach here: gradients must flow back through G)
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations, and always on the very last batch)
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)
Starting Training Loop...
[0/5][0/469]	Loss_D: 1.4302	Loss_G: 0.7566	D(x): 0.4824	D(G(z)): 0.4978 / 0.4725
[0/5][50/469]	Loss_D: 0.5910	Loss_G: 1.3961	D(x): 0.7987	D(G(z)): 0.2983 / 0.2543
[0/5][100/469]	Loss_D: 0.2957	Loss_G: 2.1519	D(x): 0.8793	D(G(z)): 0.1484 / 0.1230
[0/5][150/469]	Loss_D: 0.1266	Loss_G: 2.8365	D(x): 0.9526	D(G(z)): 0.0740 / 0.0671
[0/5][200/469]	Loss_D: 0.0684	Loss_G: 3.4099	D(x): 0.9768	D(G(z)): 0.0435 / 0.0385
[0/5][250/469]	Loss_D: 0.0600	Loss_G: 3.4485	D(x): 0.9879	D(G(z)): 0.0464 / 0.0359
[0/5][300/469]	Loss_D: 0.0282	Loss_G: 4.4007	D(x): 0.9924	D(G(z)): 0.0203 / 0.0130
[0/5][350/469]	Loss_D: 0.0158	Loss_G: 4.6839	D(x): 0.9961	D(G(z)): 0.0119 / 0.0097
[0/5][400/469]	Loss_D: 0.0149	Loss_G: 4.9847	D(x): 0.9942	D(G(z)): 0.0091 / 0.0071
[0/5][450/469]	Loss_D: 0.0158	Loss_G: 5.1067	D(x): 0.9922	D(G(z)): 0.0080 / 0.0063
[1/5][0/469]	Loss_D: 0.0155	Loss_G: 5.1242	D(x): 0.9927	D(G(z)): 0.0081 / 0.0070
[1/5][50/469]	Loss_D: 0.0105	Loss_G: 5.3981	D(x): 0.9951	D(G(z)): 0.0055 / 0.0051
[1/5][100/469]	Loss_D: 0.0100	Loss_G: 5.5080	D(x): 0.9941	D(G(z)): 0.0041 / 0.0050
[1/5][150/469]	Loss_D: 0.3127	Loss_G: 2.7590	D(x): 0.8742	D(G(z)): 0.1485 / 0.0845
[1/5][200/469]	Loss_D: 0.0488	Loss_G: 4.4514	D(x): 0.9753	D(G(z)): 0.0230 / 0.0156
[1/5][250/469]	Loss_D: 0.0232	Loss_G: 5.0328	D(x): 0.9887	D(G(z)): 0.0116 / 0.0090
[1/5][300/469]	Loss_D: 0.0370	Loss_G: 4.4716	D(x): 0.9753	D(G(z)): 0.0115 / 0.0145
[1/5][350/469]	Loss_D: 0.1731	Loss_G: 2.9877	D(x): 0.9259	D(G(z)): 0.0871 / 0.0616
[1/5][400/469]	Loss_D: 0.0611	Loss_G: 4.2089	D(x): 0.9773	D(G(z)): 0.0366 / 0.0168
[1/5][450/469]	Loss_D: 0.0835	Loss_G: 3.8094	D(x): 0.9595	D(G(z)): 0.0398 / 0.0275
[2/5][0/469]	Loss_D: 0.0563	Loss_G: 4.0580	D(x): 0.9704	D(G(z)): 0.0255 / 0.0200
[2/5][50/469]	Loss_D: 0.1323	Loss_G: 3.0354	D(x): 0.9465	D(G(z)): 0.0722 / 0.0561
[2/5][100/469]	Loss_D: 0.1094	Loss_G: 2.8894	D(x): 0.9322	D(G(z)): 0.0350 / 0.0672
[2/5][150/469]	Loss_D: 0.0991	Loss_G: 3.5336	D(x): 0.9681	D(G(z)): 0.0635 / 0.0340
[2/5][200/469]	Loss_D: 0.1453	Loss_G: 3.0698	D(x): 0.9375	D(G(z)): 0.0739 / 0.0584
[2/5][250/469]	Loss_D: 0.1785	Loss_G: 3.7610	D(x): 0.9691	D(G(z)): 0.1325 / 0.0277
[2/5][300/469]	Loss_D: 0.2715	Loss_G: 2.9310	D(x): 0.9402	D(G(z)): 0.1835 / 0.0619
[2/5][350/469]	Loss_D: 0.2133	Loss_G: 2.6066	D(x): 0.8723	D(G(z)): 0.0688 / 0.0834
[2/5][400/469]	Loss_D: 0.3833	Loss_G: 1.7446	D(x): 0.8857	D(G(z)): 0.2199 / 0.1939
[2/5][450/469]	Loss_D: 0.2304	Loss_G: 2.5357	D(x): 0.9016	D(G(z)): 0.1156 / 0.0881
[3/5][0/469]	Loss_D: 1.1075	Loss_G: 0.4192	D(x): 0.3838	D(G(z)): 0.0098 / 0.6692
[3/5][50/469]	Loss_D: 0.2231	Loss_G: 2.4358	D(x): 0.8934	D(G(z)): 0.1002 / 0.1008
[3/5][100/469]	Loss_D: 0.3565	Loss_G: 3.3957	D(x): 0.9386	D(G(z)): 0.2464 / 0.0380
[3/5][150/469]	Loss_D: 0.3635	Loss_G: 2.6130	D(x): 0.9048	D(G(z)): 0.2225 / 0.0841
[3/5][200/469]	Loss_D: 0.3656	Loss_G: 1.6708	D(x): 0.7799	D(G(z)): 0.0928 / 0.2156
[3/5][250/469]	Loss_D: 0.4319	Loss_G: 1.6588	D(x): 0.7227	D(G(z)): 0.0772 / 0.2111
[3/5][300/469]	Loss_D: 0.5896	Loss_G: 1.3397	D(x): 0.9130	D(G(z)): 0.3684 / 0.2916
[3/5][350/469]	Loss_D: 0.3948	Loss_G: 1.6931	D(x): 0.8009	D(G(z)): 0.1449 / 0.2111
[3/5][400/469]	Loss_D: 0.9616	Loss_G: 4.6688	D(x): 0.9684	D(G(z)): 0.5738 / 0.0119
[3/5][450/469]	Loss_D: 0.4665	Loss_G: 1.7228	D(x): 0.7051	D(G(z)): 0.0841 / 0.2053
[4/5][0/469]	Loss_D: 0.3386	Loss_G: 2.5990	D(x): 0.8898	D(G(z)): 0.1873 / 0.0863
[4/5][50/469]	Loss_D: 0.4121	Loss_G: 1.7139	D(x): 0.7414	D(G(z)): 0.0836 / 0.2097
[4/5][100/469]	Loss_D: 0.4266	Loss_G: 2.1115	D(x): 0.8562	D(G(z)): 0.2215 / 0.1389
[4/5][150/469]	Loss_D: 0.3552	Loss_G: 1.9983	D(x): 0.8367	D(G(z)): 0.1506 / 0.1587
[4/5][200/469]	Loss_D: 0.4002	Loss_G: 1.9206	D(x): 0.8362	D(G(z)): 0.1868 / 0.1654
[4/5][250/469]	Loss_D: 0.3310	Loss_G: 1.8705	D(x): 0.8673	D(G(z)): 0.1612 / 0.1773
[4/5][300/469]	Loss_D: 0.4438	Loss_G: 3.5810	D(x): 0.8846	D(G(z)): 0.2583 / 0.0340
[4/5][350/469]	Loss_D: 0.4048	Loss_G: 1.7204	D(x): 0.7665	D(G(z)): 0.1109 / 0.2074
[4/5][400/469]	Loss_D: 0.4599	Loss_G: 1.8826	D(x): 0.8392	D(G(z)): 0.2288 / 0.1814
[4/5][450/469]	Loss_D: 0.4831	Loss_G: 1.9424	D(x): 0.9369	D(G(z)): 0.3246 / 0.1666

Visualization of the results¶

In [4]:
# plot the loss for generator and discriminator
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

# Grab a batch of real images from the dataloader
plot_real_fake_images(next(iter(dataloader)), img_list)

Task 1.1 Ablation study on batch normalization¶

  1. Please modify the code provided in Task 1.0 so that the neural network architecture does not contain any batch normalization layers.

Hint: modify the *Architectural design for generator and discriminator* section in Task 1.0.
  2. Train the model with modified networks and visualize the results.

In [5]:
# Generator Code
class Generator_woBN(nn.Module):
    """DCGAN generator with every BatchNorm layer removed (Task 1.1
    ablation); otherwise identical to Generator."""

    def __init__(self, ngpu):
        super(Generator_woBN, self).__init__()
        self.ngpu = ngpu
        layers = [
            ################################ YOUR CODE ################################
            # nz x 1 x 1 -> (ngf*4) x 4 x 4 (BatchNorm removed for the ablation)
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(True),  # in-place ReLU
            # -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # -> nc x 32 x 32, tanh to match data normalized to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            ############################# END YOUR CODE ##############################
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator_woBN(nn.Module):
    """DCGAN discriminator with every BatchNorm layer removed (Task 1.1
    ablation); otherwise identical to Discriminator."""

    def __init__(self, ngpu):
        super(Discriminator_woBN, self).__init__()
        self.ngpu = ngpu
        layers = [
            ################################ YOUR CODE ################################
            # nc x 32 x 32 -> ndf x 16 x 16
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 8 x 8 (BatchNorm removed for the ablation)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4 (BatchNorm removed for the ablation)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1, interpreted as P(real)
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            ############################# END YOUR CODE ##############################
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

# Instantiate the no-batch-norm generator and discriminator for the ablation run
netG_noBN = initialize_net(Generator_woBN, weights_init, device, ngpu)
netD_noBN = initialize_net(Discriminator_woBN, weights_init, device, ngpu)
Generator_woBN(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): Tanh()
  )
)
Discriminator_woBN(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (7): Sigmoid()
  )
)
In [6]:
# Setup Adam optimizers for both G and D
optimizerD_noBN = optim.Adam(netD_noBN.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG_noBN = optim.Adam(netG_noBN.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop
#
# Identical to the Task 1.0 baseline loop, but trains the no-batch-norm
# networks so the results can be compared against the baseline.

# Lists to keep track of progress
img_list = []   # sample grids from fixed_noise, for visualizing progress
G_losses = []   # per-iteration generator loss
D_losses = []   # per-iteration discriminator loss
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD_noBN.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD_noBN(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG_noBN(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # (detach so this backward pass does not flow into G's graph)
        output = netD_noBN(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        # (these ACCUMULATE onto the real-batch gradients from above)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD_noBN.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG_noBN.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # (no detach here: gradients must flow back through G)
        output = netD_noBN(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG_noBN.step()
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations, and always on the very last batch)
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG_noBN(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Starting Training Loop...
[0/5][0/469]	Loss_D: 1.3862	Loss_G: 0.6931	D(x): 0.5001	D(G(z)): 0.5000 / 0.5000
[0/5][50/469]	Loss_D: 1.1952	Loss_G: 0.4654	D(x): 0.8301	D(G(z)): 0.6348 / 0.6280
[0/5][100/469]	Loss_D: 0.8018	Loss_G: 0.9937	D(x): 0.7312	D(G(z)): 0.3840 / 0.3707
[0/5][150/469]	Loss_D: 0.7584	Loss_G: 0.6824	D(x): 0.9520	D(G(z)): 0.5046 / 0.5058
[0/5][200/469]	Loss_D: 1.0215	Loss_G: 0.5153	D(x): 0.9189	D(G(z)): 0.6033 / 0.5978
[0/5][250/469]	Loss_D: 1.2340	Loss_G: 0.4860	D(x): 0.7746	D(G(z)): 0.6177 / 0.6153
[0/5][300/469]	Loss_D: 0.9597	Loss_G: 0.6933	D(x): 0.7836	D(G(z)): 0.4970 / 0.5001
[0/5][350/469]	Loss_D: 0.6789	Loss_G: 0.8237	D(x): 0.9144	D(G(z)): 0.4391 / 0.4392
[0/5][400/469]	Loss_D: 0.6653	Loss_G: 0.7524	D(x): 0.9736	D(G(z)): 0.4712 / 0.4715
[0/5][450/469]	Loss_D: 0.6898	Loss_G: 0.7107	D(x): 0.9877	D(G(z)): 0.4918 / 0.4914
[1/5][0/469]	Loss_D: 0.7180	Loss_G: 0.6825	D(x): 0.9883	D(G(z)): 0.5063 / 0.5055
[1/5][50/469]	Loss_D: 0.7244	Loss_G: 0.6933	D(x): 0.9720	D(G(z)): 0.4993 / 0.5000
[1/5][100/469]	Loss_D: 0.6929	Loss_G: 0.7126	D(x): 0.9869	D(G(z)): 0.4922 / 0.4904
[1/5][150/469]	Loss_D: 0.6935	Loss_G: 0.6994	D(x): 0.9958	D(G(z)): 0.4979 / 0.4969
[1/5][200/469]	Loss_D: 0.7107	Loss_G: 0.6811	D(x): 0.9951	D(G(z)): 0.5060 / 0.5060
[1/5][250/469]	Loss_D: 0.7192	Loss_G: 0.6702	D(x): 0.9979	D(G(z)): 0.5118 / 0.5116
[1/5][300/469]	Loss_D: 0.7956	Loss_G: 0.6045	D(x): 0.9971	D(G(z)): 0.5473 / 0.5464
[1/5][350/469]	Loss_D: 0.8436	Loss_G: 0.5761	D(x): 0.9900	D(G(z)): 0.5653 / 0.5622
[1/5][400/469]	Loss_D: 0.7847	Loss_G: 0.6447	D(x): 0.9692	D(G(z)): 0.5290 / 0.5249
[1/5][450/469]	Loss_D: 0.7689	Loss_G: 0.6764	D(x): 0.9510	D(G(z)): 0.5100 / 0.5084
[2/5][0/469]	Loss_D: 0.7767	Loss_G: 0.6606	D(x): 0.9559	D(G(z)): 0.5173 / 0.5166
[2/5][50/469]	Loss_D: 0.8872	Loss_G: 0.5680	D(x): 0.9541	D(G(z)): 0.5677 / 0.5668
[2/5][100/469]	Loss_D: 0.7708	Loss_G: 0.6685	D(x): 0.9536	D(G(z)): 0.5144 / 0.5125
[2/5][150/469]	Loss_D: 0.7843	Loss_G: 0.6686	D(x): 0.9519	D(G(z)): 0.5197 / 0.5128
[2/5][200/469]	Loss_D: 0.8189	Loss_G: 0.6323	D(x): 0.9567	D(G(z)): 0.5382 / 0.5317
[2/5][250/469]	Loss_D: 0.7841	Loss_G: 0.6547	D(x): 0.9567	D(G(z)): 0.5212 / 0.5200
[2/5][300/469]	Loss_D: 0.7367	Loss_G: 0.7039	D(x): 0.9652	D(G(z)): 0.5031 / 0.4947
[2/5][350/469]	Loss_D: 0.8539	Loss_G: 0.5959	D(x): 0.9603	D(G(z)): 0.5551 / 0.5513
[2/5][400/469]	Loss_D: 0.9855	Loss_G: 0.5762	D(x): 0.9040	D(G(z)): 0.5759 / 0.5664
[2/5][450/469]	Loss_D: 1.0334	Loss_G: 0.5133	D(x): 0.9099	D(G(z)): 0.6051 / 0.5999
[3/5][0/469]	Loss_D: 1.1025	Loss_G: 0.4614	D(x): 0.9130	D(G(z)): 0.6339 / 0.6310
[3/5][50/469]	Loss_D: 1.0680	Loss_G: 0.5253	D(x): 0.8583	D(G(z)): 0.5949 / 0.5919
[3/5][100/469]	Loss_D: 0.9963	Loss_G: 0.6087	D(x): 0.8461	D(G(z)): 0.5567 / 0.5451
[3/5][150/469]	Loss_D: 0.8960	Loss_G: 0.6370	D(x): 0.8868	D(G(z)): 0.5346 / 0.5294
[3/5][200/469]	Loss_D: 0.9942	Loss_G: 0.5689	D(x): 0.8718	D(G(z)): 0.5720 / 0.5668
[3/5][250/469]	Loss_D: 0.9826	Loss_G: 0.5894	D(x): 0.8683	D(G(z)): 0.5663 / 0.5554
[3/5][300/469]	Loss_D: 0.9445	Loss_G: 0.5905	D(x): 0.8849	D(G(z)): 0.5583 / 0.5546
[3/5][350/469]	Loss_D: 0.7998	Loss_G: 0.7788	D(x): 0.8914	D(G(z)): 0.4913 / 0.4596
[3/5][400/469]	Loss_D: 0.7953	Loss_G: 0.7811	D(x): 0.8861	D(G(z)): 0.4865 / 0.4592
[3/5][450/469]	Loss_D: 0.7482	Loss_G: 0.7471	D(x): 0.9025	D(G(z)): 0.4673 / 0.4743
[4/5][0/469]	Loss_D: 0.8973	Loss_G: 0.6918	D(x): 0.8655	D(G(z)): 0.5208 / 0.5015
[4/5][50/469]	Loss_D: 0.8042	Loss_G: 0.6950	D(x): 0.8660	D(G(z)): 0.4763 / 0.4996
[4/5][100/469]	Loss_D: 0.7752	Loss_G: 0.7667	D(x): 0.8943	D(G(z)): 0.4820 / 0.4648
[4/5][150/469]	Loss_D: 0.7470	Loss_G: 0.7153	D(x): 0.9511	D(G(z)): 0.5003 / 0.4895
[4/5][200/469]	Loss_D: 0.7707	Loss_G: 0.6881	D(x): 0.9481	D(G(z)): 0.5111 / 0.5028
[4/5][250/469]	Loss_D: 0.7897	Loss_G: 0.7147	D(x): 0.9079	D(G(z)): 0.4980 / 0.4898
[4/5][300/469]	Loss_D: 0.7853	Loss_G: 0.6860	D(x): 0.9417	D(G(z)): 0.5138 / 0.5042
[4/5][350/469]	Loss_D: 0.7138	Loss_G: 0.7425	D(x): 0.9137	D(G(z)): 0.4608 / 0.4764
[4/5][400/469]	Loss_D: 0.7576	Loss_G: 0.7259	D(x): 0.9282	D(G(z)): 0.4929 / 0.4842
[4/5][450/469]	Loss_D: 0.7458	Loss_G: 0.7083	D(x): 0.9530	D(G(z)): 0.5004 / 0.4928
In [7]:
# plot the loss for generator and discriminator
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

# Grab a batch of real images from the dataloader
plot_real_fake_images(next(iter(dataloader)), img_list)

Task 1.2 Ablation study on the trick: "Construct different mini-batches for real and fake"¶

  1. Please modify the code provided in the Task 1.0 so that the discriminator algorithm part computes the forward and backward pass for fake and real images concatenated together (with their corresponding fake and real labels concatenated as well) instead of computing the forward and backward passes for fake and real images separately.

Hint: modify the *Loss function and Training function* section in Task 1.0.
  2. Train the model with modified networks and visualize the results.

In [8]:
# re-initilizate networks for the generator and discrimintor.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        ################################ YOUR CODE ################################

        ## Get all-real half-batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label_real = torch.full((b_size,), real_label, device=device)

        ## Get all-fake half-batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label_fake = torch.full((b_size,), fake_label, device=device)

        ## Combine both half-batches
        images = torch.cat((real_cpu, fake.detach()), dim = 0)
        label = torch.cat((label_real, label_fake), dim = 0)
        # Forward pass full batch through D
        output = netD(images).view(-1)
        # Calculate loss
        errD = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD.backward()
        D_x = output.mean().item()
        # Update D
        optimizerD.step()

        ############################ END YOUR CODE ##############################

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label = torch.full((b_size,), real_label, device=device)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(G(z)): %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)
Starting Training Loop...
[0/5][0/469]	Loss_D: 0.7323	Loss_G: 0.7700	D(G(z)): 0.4673
[0/5][50/469]	Loss_D: 0.0267	Loss_G: 0.1055	D(G(z)): 0.9002
[0/5][100/469]	Loss_D: 0.0101	Loss_G: 0.0410	D(G(z)): 0.9598
[0/5][150/469]	Loss_D: 0.0059	Loss_G: 0.0247	D(G(z)): 0.9756
[0/5][200/469]	Loss_D: 0.0038	Loss_G: 0.0168	D(G(z)): 0.9833
[0/5][250/469]	Loss_D: 0.0027	Loss_G: 0.0124	D(G(z)): 0.9877
[0/5][300/469]	Loss_D: 0.0020	Loss_G: 0.0094	D(G(z)): 0.9906
[0/5][350/469]	Loss_D: 0.0016	Loss_G: 0.0074	D(G(z)): 0.9926
[0/5][400/469]	Loss_D: 0.0013	Loss_G: 0.0060	D(G(z)): 0.9940
[0/5][450/469]	Loss_D: 0.0011	Loss_G: 0.0048	D(G(z)): 0.9952
[1/5][0/469]	Loss_D: 0.0010	Loss_G: 0.0046	D(G(z)): 0.9954
[1/5][50/469]	Loss_D: 0.0008	Loss_G: 0.0039	D(G(z)): 0.9962
[1/5][100/469]	Loss_D: 0.0007	Loss_G: 0.0034	D(G(z)): 0.9966
[1/5][150/469]	Loss_D: 0.0006	Loss_G: 0.0030	D(G(z)): 0.9970
[1/5][200/469]	Loss_D: 0.0005	Loss_G: 0.0027	D(G(z)): 0.9973
[1/5][250/469]	Loss_D: 0.0005	Loss_G: 0.0023	D(G(z)): 0.9977
[1/5][300/469]	Loss_D: 0.0004	Loss_G: 0.0022	D(G(z)): 0.9978
[1/5][350/469]	Loss_D: 0.0004	Loss_G: 0.0019	D(G(z)): 0.9981
[1/5][400/469]	Loss_D: 0.0003	Loss_G: 0.0018	D(G(z)): 0.9982
[1/5][450/469]	Loss_D: 0.0003	Loss_G: 0.0016	D(G(z)): 0.9984
[2/5][0/469]	Loss_D: 0.0003	Loss_G: 0.0016	D(G(z)): 0.9984
[2/5][50/469]	Loss_D: 0.0003	Loss_G: 0.0015	D(G(z)): 0.9985
[2/5][100/469]	Loss_D: 0.0002	Loss_G: 0.0014	D(G(z)): 0.9986
[2/5][150/469]	Loss_D: 0.0002	Loss_G: 0.0012	D(G(z)): 0.9988
[2/5][200/469]	Loss_D: 0.0002	Loss_G: 0.0012	D(G(z)): 0.9988
[2/5][250/469]	Loss_D: 0.0002	Loss_G: 0.0011	D(G(z)): 0.9989
[2/5][300/469]	Loss_D: 0.0002	Loss_G: 0.0010	D(G(z)): 0.9990
[2/5][350/469]	Loss_D: 0.0002	Loss_G: 0.0010	D(G(z)): 0.9990
[2/5][400/469]	Loss_D: 0.0002	Loss_G: 0.0009	D(G(z)): 0.9991
[2/5][450/469]	Loss_D: 0.0001	Loss_G: 0.0009	D(G(z)): 0.9991
[3/5][0/469]	Loss_D: 0.0001	Loss_G: 0.0008	D(G(z)): 0.9992
[3/5][50/469]	Loss_D: 0.0001	Loss_G: 0.0008	D(G(z)): 0.9992
[3/5][100/469]	Loss_D: 0.0001	Loss_G: 0.0008	D(G(z)): 0.9992
[3/5][150/469]	Loss_D: 0.0001	Loss_G: 0.0007	D(G(z)): 0.9993
[3/5][200/469]	Loss_D: 0.0001	Loss_G: 0.0007	D(G(z)): 0.9993
[3/5][250/469]	Loss_D: 0.0001	Loss_G: 0.0006	D(G(z)): 0.9994
[3/5][300/469]	Loss_D: 0.0001	Loss_G: 0.0006	D(G(z)): 0.9994
[3/5][350/469]	Loss_D: 0.0001	Loss_G: 0.0006	D(G(z)): 0.9994
[3/5][400/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[3/5][450/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[4/5][0/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[4/5][50/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[4/5][100/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[4/5][150/469]	Loss_D: 0.0001	Loss_G: 0.0005	D(G(z)): 0.9995
[4/5][200/469]	Loss_D: 0.0001	Loss_G: 0.0004	D(G(z)): 0.9996
[4/5][250/469]	Loss_D: 0.0001	Loss_G: 0.0004	D(G(z)): 0.9996
[4/5][300/469]	Loss_D: 0.0001	Loss_G: 0.0004	D(G(z)): 0.9996
[4/5][350/469]	Loss_D: 0.0001	Loss_G: 0.0004	D(G(z)): 0.9996
[4/5][400/469]	Loss_D: 0.0001	Loss_G: 0.0004	D(G(z)): 0.9996
[4/5][450/469]	Loss_D: 0.0001	Loss_G: 0.0003	D(G(z)): 0.9997
In [9]:
# Visualize the generator/discriminator training curves
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

# Compare one batch of real digits against the saved generator snapshots
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)

Task 1.3 Ablation study on the generator's loss function¶

  1. Please modify the code provided in Task 1.0 so that the generator update minimizes $\log(1-D(G(z)))$ (the original minimax objective) instead of the modified, non-saturating loss $-\log(D(G(z)))$ suggested in the original GAN paper.
    1. Modify the *Loss function and Training function* section in Task 1.0
    2. (Hint) Try to understand the definition of BCE loss first and how the modified loss function was implemented.
  2. Train the model with modified networks and visualize the results.
In [10]:
# Re-initialize the generator and discriminator networks so this ablation
# starts from fresh DCGAN-style weights.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []   # periodic snapshots of G(fixed_noise) for visualization
G_losses = []   # generator loss recorded every batch
D_losses = []   # discriminator loss recorded every batch
iters = 0       # global iteration counter across epochs

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch: BCE with target 1 gives -log(D(x))
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() so no gradient reaches G here
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch: BCE with target 0 gives -log(1 - D(G(z)))
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch (accumulates with the real-batch gradients)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network
        ###########################

        ################################ YOUR CODE ################################

        netG.zero_grad()
        label.fill_(1.0 - real_label)  # target 0: BCE(output, 0) = -log(1 - D(G(z)))
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = -criterion(output, label) # negate: errG = log(1 - D(G(z))), the original minimax loss, minimized directly
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        ############################ END YOUR CODE ##############################

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)
Starting Training Loop...
[0/5][0/469]	Loss_D: 1.4907	Loss_G: -0.7538	D(x): 0.5020	D(G(z)): 0.5416 / 0.5226
[0/5][50/469]	Loss_D: 0.5491	Loss_G: -0.2844	D(x): 0.8204	D(G(z)): 0.2900 / 0.2448
[0/5][100/469]	Loss_D: 0.1403	Loss_G: -0.0741	D(x): 0.9484	D(G(z)): 0.0824 / 0.0711
[0/5][150/469]	Loss_D: 0.0854	Loss_G: -0.0412	D(x): 0.9664	D(G(z)): 0.0487 / 0.0401
[0/5][200/469]	Loss_D: 0.0195	Loss_G: -0.0105	D(x): 0.9915	D(G(z)): 0.0109 / 0.0104
[0/5][250/469]	Loss_D: 0.0174	Loss_G: -0.0097	D(x): 0.9930	D(G(z)): 0.0103 / 0.0096
[0/5][300/469]	Loss_D: 0.0191	Loss_G: -0.0115	D(x): 0.9942	D(G(z)): 0.0132 / 0.0114
[0/5][350/469]	Loss_D: 0.0146	Loss_G: -0.0081	D(x): 0.9936	D(G(z)): 0.0081 / 0.0081
[0/5][400/469]	Loss_D: 0.0097	Loss_G: -0.0048	D(x): 0.9951	D(G(z)): 0.0048 / 0.0048
[0/5][450/469]	Loss_D: 0.0055	Loss_G: -0.0029	D(x): 0.9976	D(G(z)): 0.0031 / 0.0028
[1/5][0/469]	Loss_D: 0.0053	Loss_G: -0.0026	D(x): 0.9973	D(G(z)): 0.0026 / 0.0026
[1/5][50/469]	Loss_D: 0.0051	Loss_G: -0.0024	D(x): 0.9972	D(G(z)): 0.0022 / 0.0024
[1/5][100/469]	Loss_D: 0.0038	Loss_G: -0.0019	D(x): 0.9981	D(G(z)): 0.0019 / 0.0019
[1/5][150/469]	Loss_D: 0.0034	Loss_G: -0.0018	D(x): 0.9985	D(G(z)): 0.0019 / 0.0018
[1/5][200/469]	Loss_D: 0.0029	Loss_G: -0.0015	D(x): 0.9985	D(G(z)): 0.0015 / 0.0015
[1/5][250/469]	Loss_D: 0.0023	Loss_G: -0.0013	D(x): 0.9990	D(G(z)): 0.0013 / 0.0013
[1/5][300/469]	Loss_D: 0.0030	Loss_G: -0.0014	D(x): 0.9983	D(G(z)): 0.0012 / 0.0014
[1/5][350/469]	Loss_D: 0.0018	Loss_G: -0.0010	D(x): 0.9991	D(G(z)): 0.0010 / 0.0010
[1/5][400/469]	Loss_D: 0.0016	Loss_G: -0.0008	D(x): 0.9993	D(G(z)): 0.0009 / 0.0008
[1/5][450/469]	Loss_D: 0.0015	Loss_G: -0.0008	D(x): 0.9993	D(G(z)): 0.0008 / 0.0008
[2/5][0/469]	Loss_D: 0.0012	Loss_G: -0.0007	D(x): 0.9995	D(G(z)): 0.0007 / 0.0007
[2/5][50/469]	Loss_D: 0.0014	Loss_G: -0.0006	D(x): 0.9993	D(G(z)): 0.0006 / 0.0006
[2/5][100/469]	Loss_D: 0.0010	Loss_G: -0.0005	D(x): 0.9995	D(G(z)): 0.0005 / 0.0005
[2/5][150/469]	Loss_D: 0.0009	Loss_G: -0.0005	D(x): 0.9996	D(G(z)): 0.0005 / 0.0005
[2/5][200/469]	Loss_D: 0.0010	Loss_G: -0.0005	D(x): 0.9994	D(G(z)): 0.0005 / 0.0005
[2/5][250/469]	Loss_D: 0.0007	Loss_G: -0.0004	D(x): 0.9997	D(G(z)): 0.0004 / 0.0004
[2/5][300/469]	Loss_D: 0.0008	Loss_G: -0.0004	D(x): 0.9996	D(G(z)): 0.0004 / 0.0004
[2/5][350/469]	Loss_D: 0.0006	Loss_G: -0.0003	D(x): 0.9997	D(G(z)): 0.0003 / 0.0003
[2/5][400/469]	Loss_D: 0.0007	Loss_G: -0.0003	D(x): 0.9996	D(G(z)): 0.0003 / 0.0003
[2/5][450/469]	Loss_D: 0.0006	Loss_G: -0.0003	D(x): 0.9997	D(G(z)): 0.0003 / 0.0003
[3/5][0/469]	Loss_D: 0.0005	Loss_G: -0.0003	D(x): 0.9998	D(G(z)): 0.0003 / 0.0003
[3/5][50/469]	Loss_D: 0.0005	Loss_G: -0.0003	D(x): 0.9998	D(G(z)): 0.0003 / 0.0003
[3/5][100/469]	Loss_D: 0.0005	Loss_G: -0.0003	D(x): 0.9998	D(G(z)): 0.0003 / 0.0003
[3/5][150/469]	Loss_D: 0.0005	Loss_G: -0.0002	D(x): 0.9997	D(G(z)): 0.0002 / 0.0002
[3/5][200/469]	Loss_D: 0.0004	Loss_G: -0.0002	D(x): 0.9998	D(G(z)): 0.0002 / 0.0002
[3/5][250/469]	Loss_D: 0.0004	Loss_G: -0.0002	D(x): 0.9998	D(G(z)): 0.0002 / 0.0002
[3/5][300/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9999	D(G(z)): 0.0002 / 0.0002
[3/5][350/469]	Loss_D: 0.0004	Loss_G: -0.0002	D(x): 0.9998	D(G(z)): 0.0002 / 0.0002
[3/5][400/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9999	D(G(z)): 0.0002 / 0.0002
[3/5][450/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9999	D(G(z)): 0.0002 / 0.0002
[4/5][0/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9998	D(G(z)): 0.0002 / 0.0002
[4/5][50/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9999	D(G(z)): 0.0002 / 0.0002
[4/5][100/469]	Loss_D: 0.0003	Loss_G: -0.0002	D(x): 0.9999	D(G(z)): 0.0002 / 0.0002
[4/5][150/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][200/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][250/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][300/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][350/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][400/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
[4/5][450/469]	Loss_D: 0.0002	Loss_G: -0.0001	D(x): 0.9999	D(G(z)): 0.0001 / 0.0001
In [11]:
# Visualize the generator/discriminator training curves for Task 1.3
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

# Compare one batch of real digits against the saved generator snapshots
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)

Task 1.4 Ablation study on the weight initialization¶

  1. Please use the function initialize_net provided in Task 1.0 to initialize the generator and discriminator function without weight initialization (HINT: There is no need to modify the code for initialize_net function).
  2. Train the model with modified networks and visualize the results.
In [12]:
################################ YOUR CODE ################################
# Build fresh networks WITHOUT the custom DCGAN weight initialization:
# passing None as the init function leaves PyTorch's default layer init.
# (Removed a stray trailing backslash that line-continued the statement
# into the END-YOUR-CODE comment below.)
netG_woinit = initialize_net(Generator, None, device, ngpu)
netD_woinit = initialize_net(Discriminator, None, device, ngpu)
###########################  END YOUR CODE ###############################
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)
In [13]:
# Setup Adam optimizers for both G and D (networks without custom weight init)
optimizerD_woinit = optim.Adam(netD_woinit.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG_woinit = optim.Adam(netG_woinit.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []   # periodic snapshots of G(fixed_noise) for visualization
G_losses = []   # generator loss recorded every batch
D_losses = []   # discriminator loss recorded every batch
iters = 0       # global iteration counter across epochs

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD_woinit.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD_woinit(real_cpu).view(-1)
        # Calculate loss on all-real batch: BCE with target 1 gives -log(D(x))
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG_woinit(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() so no gradient reaches G here
        output = netD_woinit(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch: BCE with target 0 gives -log(1 - D(G(z)))
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch (accumulates with the real-batch gradients)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD_woinit.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG_woinit.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost (non-saturating loss)
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD_woinit(fake).view(-1)
        # Calculate G's loss based on this output: BCE with target 1 gives -log(D(G(z)))
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG_woinit.step()
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG_woinit(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Starting Training Loop...
[0/5][0/469]	Loss_D: 1.4309	Loss_G: 0.6981	D(x): 0.5045	D(G(z)): 0.5139 / 0.5034
[0/5][50/469]	Loss_D: 0.6675	Loss_G: 1.1118	D(x): 0.8094	D(G(z)): 0.3605 / 0.3360
[0/5][100/469]	Loss_D: 0.5075	Loss_G: 1.5336	D(x): 0.8203	D(G(z)): 0.2553 / 0.2255
[0/5][150/469]	Loss_D: 0.3651	Loss_G: 1.9577	D(x): 0.8672	D(G(z)): 0.1912 / 0.1488
[0/5][200/469]	Loss_D: 0.2252	Loss_G: 2.5734	D(x): 0.9104	D(G(z)): 0.1176 / 0.0874
[0/5][250/469]	Loss_D: 0.1674	Loss_G: 3.0178	D(x): 0.9270	D(G(z)): 0.0832 / 0.0546
[0/5][300/469]	Loss_D: 0.0663	Loss_G: 3.8738	D(x): 0.9726	D(G(z)): 0.0366 / 0.0244
[0/5][350/469]	Loss_D: 0.0798	Loss_G: 4.2964	D(x): 0.9673	D(G(z)): 0.0436 / 0.0177
[0/5][400/469]	Loss_D: 0.0458	Loss_G: 4.1492	D(x): 0.9835	D(G(z)): 0.0283 / 0.0179
[0/5][450/469]	Loss_D: 0.0921	Loss_G: 3.8244	D(x): 0.9627	D(G(z)): 0.0500 / 0.0274
[1/5][0/469]	Loss_D: 0.0434	Loss_G: 4.1880	D(x): 0.9861	D(G(z)): 0.0286 / 0.0178
[1/5][50/469]	Loss_D: 0.0449	Loss_G: 4.6013	D(x): 0.9786	D(G(z)): 0.0227 / 0.0119
[1/5][100/469]	Loss_D: 0.0507	Loss_G: 4.3958	D(x): 0.9765	D(G(z)): 0.0257 / 0.0148
[1/5][150/469]	Loss_D: 0.0395	Loss_G: 4.4340	D(x): 0.9857	D(G(z)): 0.0244 / 0.0139
[1/5][200/469]	Loss_D: 0.0393	Loss_G: 4.7964	D(x): 0.9733	D(G(z)): 0.0110 / 0.0103
[1/5][250/469]	Loss_D: 0.1721	Loss_G: 6.1910	D(x): 0.8655	D(G(z)): 0.0021 / 0.0031
[1/5][300/469]	Loss_D: 0.0835	Loss_G: 3.9084	D(x): 0.9669	D(G(z)): 0.0459 / 0.0259
[1/5][350/469]	Loss_D: 0.0956	Loss_G: 4.1958	D(x): 0.9626	D(G(z)): 0.0451 / 0.0187
[1/5][400/469]	Loss_D: 0.0348	Loss_G: 4.4310	D(x): 0.9852	D(G(z)): 0.0188 / 0.0142
[1/5][450/469]	Loss_D: 0.0471	Loss_G: 4.3491	D(x): 0.9740	D(G(z)): 0.0191 / 0.0157
[2/5][0/469]	Loss_D: 0.0333	Loss_G: 4.5120	D(x): 0.9898	D(G(z)): 0.0225 / 0.0141
[2/5][50/469]	Loss_D: 0.0354	Loss_G: 4.7844	D(x): 0.9898	D(G(z)): 0.0243 / 0.0116
[2/5][100/469]	Loss_D: 0.0749	Loss_G: 3.9873	D(x): 0.9935	D(G(z)): 0.0632 / 0.0287
[2/5][150/469]	Loss_D: 0.1097	Loss_G: 3.7743	D(x): 0.9328	D(G(z)): 0.0308 / 0.0342
[2/5][200/469]	Loss_D: 0.0985	Loss_G: 3.8010	D(x): 0.9696	D(G(z)): 0.0620 / 0.0335
[2/5][250/469]	Loss_D: 0.1394	Loss_G: 3.5908	D(x): 0.9653	D(G(z)): 0.0922 / 0.0348
[2/5][300/469]	Loss_D: 0.2097	Loss_G: 2.9469	D(x): 0.8464	D(G(z)): 0.0142 / 0.0711
[2/5][350/469]	Loss_D: 0.0952	Loss_G: 3.8474	D(x): 0.9582	D(G(z)): 0.0480 / 0.0299
[2/5][400/469]	Loss_D: 0.5481	Loss_G: 4.6357	D(x): 0.6562	D(G(z)): 0.0076 / 0.0141
[2/5][450/469]	Loss_D: 0.1067	Loss_G: 3.4584	D(x): 0.9367	D(G(z)): 0.0348 / 0.0441
[3/5][0/469]	Loss_D: 0.0826	Loss_G: 3.9818	D(x): 0.9645	D(G(z)): 0.0433 / 0.0278
[3/5][50/469]	Loss_D: 0.0711	Loss_G: 4.2012	D(x): 0.9667	D(G(z)): 0.0346 / 0.0218
[3/5][100/469]	Loss_D: 0.1145	Loss_G: 3.5186	D(x): 0.9422	D(G(z)): 0.0470 / 0.0456
[3/5][150/469]	Loss_D: 0.1437	Loss_G: 3.3226	D(x): 0.9394	D(G(z)): 0.0713 / 0.0527
[3/5][200/469]	Loss_D: 0.1419	Loss_G: 3.6223	D(x): 0.9250	D(G(z)): 0.0551 / 0.0416
[3/5][250/469]	Loss_D: 0.3101	Loss_G: 1.5062	D(x): 0.7856	D(G(z)): 0.0204 / 0.2810
[3/5][300/469]	Loss_D: 0.2577	Loss_G: 4.0582	D(x): 0.9536	D(G(z)): 0.1658 / 0.0291
[3/5][350/469]	Loss_D: 0.2102	Loss_G: 2.1140	D(x): 0.8678	D(G(z)): 0.0437 / 0.1755
[3/5][400/469]	Loss_D: 0.2216	Loss_G: 2.3289	D(x): 0.8570	D(G(z)): 0.0466 / 0.1348
[3/5][450/469]	Loss_D: 0.2906	Loss_G: 1.7698	D(x): 0.7868	D(G(z)): 0.0196 / 0.2229
[4/5][0/469]	Loss_D: 0.1330	Loss_G: 3.7065	D(x): 0.9488	D(G(z)): 0.0726 / 0.0392
[4/5][50/469]	Loss_D: 0.1386	Loss_G: 3.0356	D(x): 0.9192	D(G(z)): 0.0460 / 0.0670
[4/5][100/469]	Loss_D: 0.1912	Loss_G: 2.5466	D(x): 0.8997	D(G(z)): 0.0712 / 0.1071
[4/5][150/469]	Loss_D: 0.8641	Loss_G: 2.1753	D(x): 0.5032	D(G(z)): 0.0050 / 0.1477
[4/5][200/469]	Loss_D: 0.1820	Loss_G: 2.9584	D(x): 0.9013	D(G(z)): 0.0662 / 0.0727
[4/5][250/469]	Loss_D: 0.1569	Loss_G: 2.9308	D(x): 0.9160	D(G(z)): 0.0582 / 0.0707
[4/5][300/469]	Loss_D: 0.1658	Loss_G: 2.8245	D(x): 0.9615	D(G(z)): 0.1122 / 0.0805
[4/5][350/469]	Loss_D: 0.2837	Loss_G: 3.4902	D(x): 0.9602	D(G(z)): 0.1948 / 0.0469
[4/5][400/469]	Loss_D: 0.1965	Loss_G: 2.9231	D(x): 0.9079	D(G(z)): 0.0830 / 0.0779
[4/5][450/469]	Loss_D: 0.2242	Loss_G: 2.6502	D(x): 0.8598	D(G(z)): 0.0458 / 0.1053
In [14]:
# Visualize the generator/discriminator training curves for Task 1.4
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

# Compare one batch of real digits against the saved generator snapshots
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)

Exercise 2: Implement the WGAN with weight clipping¶

Wasserstein GAN (WGAN) is an alternative training strategy to traditional GAN. WGAN may provide more stable learning and may avoid problems faced in traditional GAN training like mode collapse.

  1. Rewrite the loss functions and training function according to the algorithm introduced in slide 18 in Lecture note for WGAN. A few notes/hints:
    1. Keep the same generator as in Exercise 1, Task 1.0, but modify the discriminator so that there is no restriction on the range of the output. (Simply comment out the last Sigmoid layer)
    2. Modify the optimizer to be the RMSProp optimizer with a learning rate equal to the value in lr_rms (which we set to 5e-4, which is larger than the rate in the paper but works better for our purposes).
    3. Use torch.Tensor.clamp_() function to clip the parameter values. You will need to do this for all parameters of the discriminator. See algorithm for when to do this.
  2. Train the model with modified networks and visualize the results.
In [15]:
class Discriminator_WGAN(nn.Module):
    """WGAN critic: the DCGAN discriminator with its final Sigmoid removed,
    so the network emits an unbounded real-valued score instead of a
    probability.
    """

    def __init__(self, ngpu):
        super(Discriminator_WGAN, self).__init__()
        self.ngpu = ngpu
        ################################ YOUR CODE ################################
        layers = [
            # (nc) x 32 x 32 -> (ndf) x 16 x 16
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf) x 16 x 16 -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1 raw score; no Sigmoid for WGAN
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
        ]
        self.main = nn.Sequential(*layers)
        ########################### END YOUR CODE ################################

    def forward(self, input):
        """Return the critic's score map for a batch of images."""
        return self.main(input)

# Build the WGAN networks: reuse the DCGAN generator, but swap in the critic
# (discriminator without the final Sigmoid), both with DCGAN weight init.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator_WGAN, weights_init, device, ngpu)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)
Discriminator_WGAN(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
In [16]:
############################ YOUR CODE ############################
# Setup RMSprop optimizers for both netG and netD with given learning rate as `lr_rms`
# (WGAN uses RMSprop rather than Adam, per the weight-clipping algorithm)
optimizerDW = optim.RMSprop(netD.parameters(), lr=lr_rms)
optimizerGW = optim.RMSprop(netG.parameters(), lr=lr_rms)

# Define WGAN loss functions
def WDloss(Dreal, Dfake):
    """WGAN critic loss: the negated Wasserstein estimate.

    Minimizing mean(D(fake)) - mean(D(real)) is equivalent to maximizing
    mean(D(real)) - mean(D(fake)), i.e. the critic learns to score real
    samples higher than generated ones.
    """
    return Dfake.mean() - Dreal.mean()

def WGloss(Dfake):
    """WGAN generator loss: -E[D(fake)] (generator tries to raise critic scores)."""
    return -Dfake.mean()

######################## # END YOUR CODE ##########################

# Training Loop (WGAN with weight clipping)

# Lists to keep track of progress
img_list = []   # G(fixed_noise) snapshots, captured every 100 iterations
G_losses = []   # generator loss per outer iteration
D_losses = []   # critic loss of the last inner critic step per outer iteration
n_critic = 5    # critic updates per generator update
c = 0.01        # weight-clipping limit (approximate Lipschitz constraint)
dataloader_iter = iter(dataloader)

print("Starting Training Loop...")
num_iters = 1000

for iters in range(num_iters):
    
    ###########################################################################
    # (1) Train Discriminator more: minimize -(mean(D(real))-mean(D(fake)))
    ###########################################################################

    # Re-enable critic gradients (they are frozen below during the G step).
    for p in netD.parameters():
        p.requires_grad = True

    for idx_critic in range(n_critic):

        netD.zero_grad()

        # Fetch the next real batch, restarting the iterator when exhausted
        # so the loop can run for more iterations than one epoch provides.
        try:
            data = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            data = next(dataloader_iter)

        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        D_real = netD(real_cpu).view(-1)

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # Detach the fake batch: the critic's backward pass must not spend
        # time computing (and accumulating) gradients through the generator,
        # which is only updated in step (2). Results are unchanged because
        # netG.zero_grad() below would have discarded those gradients anyway.
        D_fake = netD(fake.detach()).view(-1)
        
        ############################ YOUR CODE ############################
        # Define your loss function for variable `D_loss`
        D_loss = WDloss(D_real, D_fake)

        # Backpropagate the loss function and update the optimizer
        D_loss.backward()
        optimizerDW.step()

        # Clip the critic *weights* to [-c, c] with `clamp_()` — WGAN weight
        # clipping, applied after each optimizer step to keep the critic
        # (approximately) 1-Lipschitz.
        for p in netD.parameters(): p.data.clamp_(-c, c)

        ######################## # END YOUR CODE ##########################

    ###########################################################################
    # (2) Update G network: minimize -mean(D(fake)) (Update only once in 5 epochs)
    ###########################################################################
    # Freeze the critic so the generator update touches only netG parameters.
    for p in netD.parameters():
        p.requires_grad = False
    
    netG.zero_grad()

    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake = netG(noise)
    D_fake = netD(fake).view(-1)

    ################################ YOUR CODE ################################
    # Define your loss function for variable `G_loss`
    G_loss = WGloss(D_fake)

    # Backpropagate the loss function and update the optimizer
    G_loss.backward()
    optimizerGW.step()

    ############################# END YOUR CODE ##############################

    # Output training stats
    if iters % 10 == 0:
        print('[%4d/%4d]   Loss_D: %6.4f    Loss_G: %6.4f'
            % (iters, num_iters, D_loss.item(), G_loss.item()))
    
    # Save Losses for plotting later
    G_losses.append(G_loss.item())
    D_losses.append(D_loss.item())
    
    # Check how the generator is doing by saving G's output on fixed_noise
    if (iters % 100 == 0):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
Starting Training Loop...
[   0/1000]   Loss_D: -0.0013    Loss_G: -0.0001
[  10/1000]   Loss_D: -0.0014    Loss_G: -0.0018
[  20/1000]   Loss_D: -0.0015    Loss_G: -0.0013
[  30/1000]   Loss_D: -0.0019    Loss_G: -0.0010
[  40/1000]   Loss_D: -0.0029    Loss_G: -0.0001
[  50/1000]   Loss_D: -0.0041    Loss_G: -0.0002
[  60/1000]   Loss_D: -0.0057    Loss_G: -0.0003
[  70/1000]   Loss_D: -0.0078    Loss_G: 0.0018
[  80/1000]   Loss_D: -0.0084    Loss_G: -0.0017
[  90/1000]   Loss_D: -0.0074    Loss_G: 0.0038
[ 100/1000]   Loss_D: -0.0081    Loss_G: 0.0072
[ 110/1000]   Loss_D: -0.0056    Loss_G: -0.0038
[ 120/1000]   Loss_D: -0.0073    Loss_G: -0.0009
[ 130/1000]   Loss_D: -0.0077    Loss_G: 0.0036
[ 140/1000]   Loss_D: -0.0062    Loss_G: -0.0024
[ 150/1000]   Loss_D: -0.0062    Loss_G: 0.0046
[ 160/1000]   Loss_D: -0.0066    Loss_G: 0.0009
[ 170/1000]   Loss_D: -0.0057    Loss_G: -0.0024
[ 180/1000]   Loss_D: -0.0062    Loss_G: 0.0039
[ 190/1000]   Loss_D: -0.0059    Loss_G: 0.0043
[ 200/1000]   Loss_D: -0.0050    Loss_G: 0.0083
[ 210/1000]   Loss_D: -0.0052    Loss_G: -0.0004
[ 220/1000]   Loss_D: -0.0042    Loss_G: 0.0100
[ 230/1000]   Loss_D: -0.0049    Loss_G: 0.0021
[ 240/1000]   Loss_D: -0.0051    Loss_G: 0.0045
[ 250/1000]   Loss_D: -0.0046    Loss_G: 0.0004
[ 260/1000]   Loss_D: -0.0051    Loss_G: 0.0118
[ 270/1000]   Loss_D: -0.0043    Loss_G: 0.0053
[ 280/1000]   Loss_D: -0.0031    Loss_G: -0.0029
[ 290/1000]   Loss_D: -0.0033    Loss_G: -0.0044
[ 300/1000]   Loss_D: -0.0043    Loss_G: -0.0019
[ 310/1000]   Loss_D: -0.0049    Loss_G: -0.0011
[ 320/1000]   Loss_D: -0.0037    Loss_G: -0.0048
[ 330/1000]   Loss_D: -0.0036    Loss_G: -0.0003
[ 340/1000]   Loss_D: -0.0034    Loss_G: 0.0031
[ 350/1000]   Loss_D: -0.0032    Loss_G: -0.0023
[ 360/1000]   Loss_D: -0.0035    Loss_G: 0.0034
[ 370/1000]   Loss_D: -0.0033    Loss_G: 0.0029
[ 380/1000]   Loss_D: -0.0035    Loss_G: 0.0054
[ 390/1000]   Loss_D: -0.0036    Loss_G: 0.0003
[ 400/1000]   Loss_D: -0.0035    Loss_G: 0.0029
[ 410/1000]   Loss_D: -0.0044    Loss_G: 0.0120
[ 420/1000]   Loss_D: -0.0023    Loss_G: -0.0062
[ 430/1000]   Loss_D: -0.0030    Loss_G: -0.0057
[ 440/1000]   Loss_D: -0.0032    Loss_G: 0.0005
[ 450/1000]   Loss_D: -0.0015    Loss_G: 0.0035
[ 460/1000]   Loss_D: -0.0020    Loss_G: 0.0060
[ 470/1000]   Loss_D: -0.0026    Loss_G: -0.0006
[ 480/1000]   Loss_D: -0.0036    Loss_G: 0.0053
[ 490/1000]   Loss_D: -0.0026    Loss_G: -0.0024
[ 500/1000]   Loss_D: -0.0030    Loss_G: -0.0015
[ 510/1000]   Loss_D: -0.0022    Loss_G: 0.0010
[ 520/1000]   Loss_D: -0.0025    Loss_G: 0.0019
[ 530/1000]   Loss_D: -0.0023    Loss_G: 0.0011
[ 540/1000]   Loss_D: -0.0026    Loss_G: -0.0007
[ 550/1000]   Loss_D: -0.0018    Loss_G: 0.0041
[ 560/1000]   Loss_D: -0.0024    Loss_G: -0.0030
[ 570/1000]   Loss_D: -0.0021    Loss_G: 0.0047
[ 580/1000]   Loss_D: -0.0024    Loss_G: 0.0007
[ 590/1000]   Loss_D: -0.0026    Loss_G: -0.0001
[ 600/1000]   Loss_D: -0.0022    Loss_G: 0.0034
[ 610/1000]   Loss_D: -0.0027    Loss_G: 0.0011
[ 620/1000]   Loss_D: -0.0024    Loss_G: 0.0061
[ 630/1000]   Loss_D: -0.0022    Loss_G: 0.0058
[ 640/1000]   Loss_D: -0.0023    Loss_G: 0.0060
[ 650/1000]   Loss_D: -0.0023    Loss_G: 0.0013
[ 660/1000]   Loss_D: -0.0025    Loss_G: -0.0010
[ 670/1000]   Loss_D: -0.0020    Loss_G: 0.0011
[ 680/1000]   Loss_D: -0.0023    Loss_G: 0.0022
[ 690/1000]   Loss_D: -0.0025    Loss_G: 0.0008
[ 700/1000]   Loss_D: -0.0027    Loss_G: 0.0076
[ 710/1000]   Loss_D: -0.0025    Loss_G: 0.0052
[ 720/1000]   Loss_D: -0.0021    Loss_G: -0.0001
[ 730/1000]   Loss_D: -0.0020    Loss_G: 0.0040
[ 740/1000]   Loss_D: -0.0025    Loss_G: 0.0069
[ 750/1000]   Loss_D: -0.0020    Loss_G: -0.0000
[ 760/1000]   Loss_D: -0.0025    Loss_G: 0.0016
[ 770/1000]   Loss_D: -0.0021    Loss_G: 0.0011
[ 780/1000]   Loss_D: -0.0020    Loss_G: 0.0037
[ 790/1000]   Loss_D: -0.0018    Loss_G: -0.0030
[ 800/1000]   Loss_D: -0.0025    Loss_G: 0.0008
[ 810/1000]   Loss_D: -0.0016    Loss_G: -0.0004
[ 820/1000]   Loss_D: -0.0020    Loss_G: 0.0038
[ 830/1000]   Loss_D: -0.0023    Loss_G: 0.0036
[ 840/1000]   Loss_D: -0.0017    Loss_G: -0.0000
[ 850/1000]   Loss_D: -0.0025    Loss_G: 0.0049
[ 860/1000]   Loss_D: -0.0020    Loss_G: 0.0027
[ 870/1000]   Loss_D: -0.0029    Loss_G: 0.0075
[ 880/1000]   Loss_D: -0.0019    Loss_G: -0.0039
[ 890/1000]   Loss_D: -0.0016    Loss_G: -0.0016
[ 900/1000]   Loss_D: -0.0013    Loss_G: 0.0020
[ 910/1000]   Loss_D: -0.0022    Loss_G: 0.0028
[ 920/1000]   Loss_D: -0.0016    Loss_G: 0.0032
[ 930/1000]   Loss_D: -0.0017    Loss_G: -0.0013
[ 940/1000]   Loss_D: -0.0016    Loss_G: 0.0013
[ 950/1000]   Loss_D: -0.0014    Loss_G: -0.0019
[ 960/1000]   Loss_D: -0.0018    Loss_G: 0.0001
[ 970/1000]   Loss_D: -0.0018    Loss_G: 0.0020
[ 980/1000]   Loss_D: -0.0014    Loss_G: 0.0007
[ 990/1000]   Loss_D: -0.0019    Loss_G: -0.0000
In [17]:
# Visualize training progress: G/D loss curves, then real images from the
# dataloader side by side with the saved generated samples.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])

real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)